from torch import nn

class Model(nn.Module):
    """Minimal module that owns a single 1 -> 1 linear layer.

    ``self.fc`` is registered so the module has parameters whose device can
    be inspected. ``forward`` does not apply it; it returns its input as-is.
    """

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(1, 1)

    def forward(self, x):
        """Identity pass: return ``x`` unchanged (``self.fc`` is not used)."""
        result = x
        return result

    def print_model(self):
        """Print the device of the first parameter this module yields.

        Raises ``StopIteration`` if the module has no parameters.
        """
        first_param = next(iter(self.parameters()))
        print("self.device:", first_param.device)


if __name__ == '__main__':
    # Build the model and report the device of its first parameter, both
    # from the outside and via the model's own helper.
    net = Model()
    param_device = next(net.parameters()).device
    print("model.device:", param_device)
    net.print_model()
